{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "83b7f7d5-031a-416a-8532-13a69854e0a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from haystack.document_stores import InMemoryDocumentStore\n",
    "from haystack.nodes import BM25Retriever, PromptNode, PromptTemplate, AnswerParser\n",
    "from haystack.pipelines import Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bf84458-7235-4c6c-ab76-51ad981a6fe5",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(\"bilgeyucel/seven-wonders\", split=\"train\")\n",
    "document_store = InMemoryDocumentStore(use_bm25=True)\n",
    "document_store.write_documents(dataset)\n",
    "retriever = BM25Retriever(document_store=document_store)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "24da5ccb-ae42-43fb-996e-7ce0a29fc1cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "rag_prompt = PromptTemplate(\n",
    "    prompt=\"\"\"Synthesize a comprehensive answer from the following text for the given question.\n",
    "                             Provide a clear and concise response that summarizes the key points and information presented in the text.\n",
    "                             Your answer should be in your own words and be no longer than 50 words.\n",
    "                             \\n\\n Related text: {join(documents)} \\n\\n Question: {query} \\n\\n Answer:\"\"\",\n",
    "    output_parser=AnswerParser(),\n",
    ")\n",
    "prompt_node = PromptNode(model_name_or_path=\"google/flan-t5-large\", default_prompt_template=rag_prompt, use_gpu=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "00513a27-082c-49bd-838b-ef12cb2e0258",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = Pipeline()\n",
    "pipe.add_node(component=retriever, name=\"retriever\", inputs=[\"Query\"])\n",
    "pipe.add_node(component=prompt_node, name=\"prompt_node\", inputs=[\"retriever\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a8452e5c-a2f5-4f77-b698-64eee323d784",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (2916 > 512). Running this sequence through the model will result in indexing errors\n",
      "The prompt has been truncated from 2916 tokens to 412 tokens so that the prompt length and answer length (100 tokens) fit within the max token limit (512 tokens). Shorten the prompt to prevent it from being cut off\n",
      "The prompt has been truncated from 2980 tokens to 412 tokens so that the prompt length and answer length (100 tokens) fit within the max token limit (512 tokens). Shorten the prompt to prevent it from being cut off\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The Great Pyramid of Giza is the oldest of the Seven Wonders of the Ancient World.\n",
      "The Hanging Gardens are the only one of the Seven Wonders for which the location has not been definitively established.\n"
     ]
    }
   ],
   "source": [
    "output = pipe.run(query=\"What is the Great Pyramid of Giza?\")\n",
    "print(output[\"answers\"][0].answer)\n",
    "output = pipe.run(query=\"Where are the hanging gardens?\")\n",
    "print(output[\"answers\"][0].answer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchgpu",
   "language": "python",
   "name": "torchgpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
